import torch.nn as nn
import torch.nn.functional as F


# 判别器网络
class Discriminator(nn.Module):
    """Three-layer fully connected discriminator.

    The input's last dimension (``image_size`` features) passes through two
    hidden ``nn.Linear`` layers of width ``hidden_size``. Each is followed by
    a LeakyReLU with negative slope 0.2. A final ``nn.Linear`` maps to one
    unit, and a sigmoid squashes that unit into [0, 1].

    Args:
        image_size: number of features in the last dimension of the input.
        hidden_size: width of both hidden layers.
    """

    def __init__(self, image_size: int, hidden_size: int):
        super().__init__()
        # Attribute names and assignment order determine the state_dict keys
        # and the parameters() order, so they stay exactly as originally laid out.
        self.linear1 = nn.Linear(image_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return sigmoid scores for ``x``; the last dimension of the result is 1."""
        hidden = x
        for layer in (self.linear1, self.linear2):
            # In-place is safe here: it overwrites the freshly produced Linear
            # output, never the caller's ``x``.
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.2, inplace=True)
        return F.sigmoid(self.linear3(hidden))


# 生成器网络
class Generator(nn.Module):
    """Three-layer fully connected generator.

    The input's last dimension (``latent_size`` features) passes through two
    hidden ``nn.Linear`` layers of width ``hidden_size``, each followed by a
    ReLU. A final ``nn.Linear`` produces ``image_size`` outputs, and tanh
    squashes them into [-1, 1].

    Args:
        image_size: number of output features.
        latent_size: number of features in the last dimension of the input.
        hidden_size: width of both hidden layers.
    """

    def __init__(self, image_size: int, latent_size: int, hidden_size: int):
        super().__init__()
        # Attribute names and assignment order determine the state_dict keys
        # and the parameters() order, so they stay exactly as originally laid out.
        self.linear1 = nn.Linear(latent_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, image_size)

    def forward(self, x):
        """Map ``x`` to tanh-activated outputs whose last dimension is ``image_size``."""
        features = x
        for layer in (self.linear1, self.linear2):
            features = F.relu(layer(features))
        return F.tanh(self.linear3(features))
